{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5f41a38a",
   "metadata": {},
   "source": [
    "# Tutorial 01. Block Poisson problem.\n",
    "\n",
    "In this tutorial we first solve the problem\n",
    "\n",
    "\\begin{cases}\n",
    "-u'' = f & \\text{in }\\Omega = (0, 1),\\\\\n",
    " u   = 0 & \\text{on }\\Gamma = \\{0, 1\\}\n",
    "\\end{cases}\n",
    "\n",
    "using standard (non-block) `dolfinx` code.\n",
    "\n",
    "Then we use the block support of `dolfinx` to solve the system\n",
    "\n",
    "\\begin{cases}\n",
    "- w_1'' - 2 w_2'' = 3 f & \\text{in }\\Omega,\\\\\n",
    "- 3 w_1'' - 4 w_2'' = 7 f & \\text{in }\\Omega\n",
    "\\end{cases}\n",
    "\n",
    "subject to\n",
    "\n",
    "\\begin{cases}\n",
    " w_1 = 0 & \\text{on }\\Gamma,\\\\\n",
    " w_2 = 0 & \\text{on }\\Gamma\n",
    "\\end{cases}\n",
    "\n",
    "By construction the solution of the system is\n",
    "$$(w_1, w_2) = (u, u)$$\n",
    "\n",
    "We then compare the obtained solutions. This tutorial serves as a reminder of how to solve a linear problem in `dolfinx`, and introduces `multiphenicsx.fem.petsc.BlockVecSubVectorWrapper`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e9af0ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dolfinx.fem\n",
    "import dolfinx.fem.petsc\n",
    "import dolfinx.mesh\n",
    "import mpi4py.MPI\n",
    "import numpy as np\n",
    "import numpy.typing\n",
    "import petsc4py.PETSc\n",
    "import ufl\n",
    "import viskex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cb6c1f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import multiphenicsx.fem\n",
    "import multiphenicsx.fem.petsc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7a79c61",
   "metadata": {},
   "source": [
    "### Mesh generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49ae8c4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "mesh = dolfinx.mesh.create_unit_interval(mpi4py.MPI.COMM_WORLD, 32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6be2f3e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def left(x: np.typing.NDArray[np.float64]) -> np.typing.NDArray[np.bool_]:\n",
    "    \"\"\"\n",
    "    Mark the points located at the left end of the domain.\n",
    "\n",
    "    Only `x[0]` is inspected: a point is marked when `x[0]` differs from 0 by less than machine epsilon.\n",
    "    \"\"\"\n",
    "    tolerance = np.finfo(float).eps\n",
    "    distance = abs(x[0] - 0.)\n",
    "    return distance < tolerance  # type: ignore[no-any-return]\n",
    "\n",
    "\n",
    "def right(x: np.typing.NDArray[np.float64]) -> np.typing.NDArray[np.bool_]:\n",
    "    \"\"\"\n",
    "    Mark the points located at the right end of the domain.\n",
    "\n",
    "    Only `x[0]` is inspected: a point is marked when `x[0]` differs from 1 by less than machine epsilon.\n",
    "    \"\"\"\n",
    "    tolerance = np.finfo(float).eps\n",
    "    distance = abs(x[0] - 1.)\n",
    "    return distance < tolerance  # type: ignore[no-any-return]\n",
    "\n",
    "\n",
    "# Locate the boundary facets matched by each marker (facets have dimension mesh.topology.dim - 1),\n",
    "# and then collect all of them in a single array\n",
    "left_facets, right_facets = (\n",
    "    dolfinx.mesh.locate_entities_boundary(mesh, mesh.topology.dim - 1, marker) for marker in (left, right))\n",
    "boundary_facets = np.concatenate((left_facets, right_facets))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e73c7796-9ff8-4866-9469-54195efefe5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create connectivities required by the rest of the code\n",
    "mesh.topology.create_connectivity(mesh.topology.dim - 1, mesh.topology.dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d0f035b",
   "metadata": {},
   "outputs": [],
   "source": [
    "x0 = ufl.SpatialCoordinate(mesh)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "559ce0e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_mesh(mesh)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1f03019",
   "metadata": {},
   "source": [
    "### Standard case (solve for $u$)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3b2ed58",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_standard() -> dolfinx.fem.Function:\n",
    "    \"\"\"\n",
    "    Run the standard case: solve for the unknown u with plain (non-block) dolfinx code.\n",
    "\n",
    "    The problem is posed on the notebook-level `mesh` with second order Lagrange elements, a zero\n",
    "    Dirichlet condition is imposed on `boundary_facets`, and the linear system is solved through\n",
    "    `dolfinx.fem.petsc.LinearProblem` with PETSc options selecting an LU factorization by MUMPS.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dolfinx.fem.Function\n",
    "        The computed solution u.\n",
    "    \"\"\"\n",
    "    # Function space, with its trial and test functions\n",
    "    space = dolfinx.fem.functionspace(mesh, (\"Lagrange\", 2))\n",
    "    trial = ufl.TrialFunction(space)\n",
    "    test = ufl.TestFunction(space)\n",
    "\n",
    "    # Bilinear and linear forms of the problem\n",
    "    source = 100 * ufl.sin(20 * x0)\n",
    "    lhs = ufl.inner(ufl.grad(trial), ufl.grad(test)) * ufl.dx\n",
    "    rhs = ufl.inner(source, test) * ufl.dx\n",
    "\n",
    "    # Zero Dirichlet boundary condition on the degrees of freedom attached to the boundary facets\n",
    "    boundary_dofs = dolfinx.fem.locate_dofs_topological(space, mesh.topology.dim - 1, boundary_facets)\n",
    "    bc = dolfinx.fem.dirichletbc(petsc4py.PETSc.ScalarType(0), boundary_dofs, space)\n",
    "\n",
    "    # Solve the linear system, storing the result in a newly allocated function\n",
    "    solution = dolfinx.fem.Function(space)\n",
    "    solver_options = {\"ksp_type\": \"preonly\", \"pc_type\": \"lu\", \"pc_factor_mat_solver_type\": \"mumps\"}\n",
    "    problem = dolfinx.fem.petsc.LinearProblem(lhs, rhs, bcs=[bc], u=solution, petsc_options=solver_options)\n",
    "    problem.solve()\n",
    "    # Forward scatter in insert mode, to update the ghost entries of the solution vector\n",
    "    solution.x.petsc_vec.ghostUpdate(\n",
    "        addv=petsc4py.PETSc.InsertMode.INSERT, mode=petsc4py.PETSc.ScatterMode.FORWARD)\n",
    "\n",
    "    return solution  # type: ignore[no-any-return]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76f6076a",
   "metadata": {},
   "outputs": [],
   "source": [
    "u = run_standard()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff169688",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_scalar_field(u, \"u\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e1b01f2",
   "metadata": {},
   "source": [
    "### Block case (solve for ($w_1$, $w_2$))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "944b2102",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_block() -> tuple[dolfinx.fem.Function, dolfinx.fem.Function]:\n",
    "    \"\"\"\n",
    "    Run the block case: solve for the unknowns (w_1, w_2).\n",
    "\n",
    "    The 2x2 block system is assembled with the block assembly routines of `dolfinx.fem.petsc`, solved\n",
    "    with a KSP configured as a direct solver (LU factorization by MUMPS), and the block solution is then\n",
    "    copied into one function per component. Every PETSc object created here (solver, matrix, vectors)\n",
    "    is explicitly destroyed before returning.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple[dolfinx.fem.Function, dolfinx.fem.Function]\n",
    "        The two components (w_1, w_2) of the block solution.\n",
    "    \"\"\"\n",
    "    # Define a block function space\n",
    "    V1 = dolfinx.fem.functionspace(mesh, (\"Lagrange\", 2))\n",
    "    V2 = dolfinx.fem.functionspace(mesh, (\"Lagrange\", 2))\n",
    "    (u1, u2) = (ufl.TrialFunction(V1), ufl.TrialFunction(V2))\n",
    "    (v1, v2) = (ufl.TestFunction(V1), ufl.TestFunction(V2))\n",
    "\n",
    "    # Define problem block forms\n",
    "    aa = [[1 * ufl.inner(ufl.grad(u1), ufl.grad(v1)) * ufl.dx, 2 * ufl.inner(ufl.grad(u2), ufl.grad(v1)) * ufl.dx],\n",
    "          [3 * ufl.inner(ufl.grad(u1), ufl.grad(v2)) * ufl.dx, 4 * ufl.inner(ufl.grad(u2), ufl.grad(v2)) * ufl.dx]]\n",
    "    ff = [ufl.inner(300 * ufl.sin(20 * x0), v1) * ufl.dx,\n",
    "          ufl.inner(700 * ufl.sin(20 * x0), v2) * ufl.dx]\n",
    "    aa_cpp = dolfinx.fem.form(aa)\n",
    "    ff_cpp = dolfinx.fem.form(ff)\n",
    "\n",
    "    # Define block boundary conditions\n",
    "    zero = petsc4py.PETSc.ScalarType(0)\n",
    "    bdofs_V1 = dolfinx.fem.locate_dofs_topological(V1, mesh.topology.dim - 1, boundary_facets)\n",
    "    bdofs_V2 = dolfinx.fem.locate_dofs_topological(V2, mesh.topology.dim - 1, boundary_facets)\n",
    "    bc1 = dolfinx.fem.dirichletbc(zero, bdofs_V1, V1)\n",
    "    bc2 = dolfinx.fem.dirichletbc(zero, bdofs_V2, V2)\n",
    "    bcs = [bc1, bc2]\n",
    "\n",
    "    # Assemble the block linear system\n",
    "    AA = dolfinx.fem.petsc.assemble_matrix_block(aa_cpp, bcs)\n",
    "    AA.assemble()\n",
    "    FF = dolfinx.fem.petsc.assemble_vector_block(ff_cpp, aa_cpp, bcs)\n",
    "\n",
    "    # Solve the block linear system\n",
    "    uu = dolfinx.fem.petsc.create_vector_block(ff_cpp)\n",
    "    ksp = petsc4py.PETSc.KSP()\n",
    "    ksp.create(mesh.comm)\n",
    "    ksp.setOperators(AA)\n",
    "    ksp.setType(\"preonly\")\n",
    "    ksp.getPC().setType(\"lu\")\n",
    "    ksp.getPC().setFactorSolverType(\"mumps\")\n",
    "    ksp.setFromOptions()\n",
    "    ksp.solve(FF, uu)\n",
    "    # Forward scatter in insert mode, to update the ghost entries of the block solution vector\n",
    "    uu.ghostUpdate(addv=petsc4py.PETSc.InsertMode.INSERT, mode=petsc4py.PETSc.ScatterMode.FORWARD)\n",
    "    ksp.destroy()\n",
    "    # The assembled operator and right-hand side are not used after the solve: release them explicitly,\n",
    "    # rather than relying on the Python garbage collector\n",
    "    AA.destroy()\n",
    "    FF.destroy()\n",
    "\n",
    "    # Split the block solution in components\n",
    "    u1u2 = (dolfinx.fem.Function(V1), dolfinx.fem.Function(V2))\n",
    "    with multiphenicsx.fem.petsc.BlockVecSubVectorWrapper(uu, [V1.dofmap, V2.dofmap]) as uu_wrapper:\n",
    "        for u_wrapper_local, component in zip(uu_wrapper, u1u2):\n",
    "            with component.x.petsc_vec.localForm() as component_local:\n",
    "                component_local[:] = u_wrapper_local\n",
    "    # The values have been copied into the components, hence the block vector can be released as well\n",
    "    uu.destroy()\n",
    "\n",
    "    # Return the block solution components\n",
    "    return u1u2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f7a2ee2",
   "metadata": {},
   "outputs": [],
   "source": [
    "u1, u2 = run_block()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a5d1754",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_scalar_field(u1, \"u1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5accaa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_scalar_field(u2, \"u2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79ca048f",
   "metadata": {},
   "source": [
    "### Error computation between the different cases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f414950e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _h1_seminorm(w: ufl.core.expr.Expr) -> np.float64:\n",
    "    \"\"\"Compute the L2 norm of the gradient of w, summing the contributions of all processes in mesh.comm.\"\"\"\n",
    "    squared_norm_form = dolfinx.fem.form(ufl.inner(ufl.grad(w), ufl.grad(w)) * ufl.dx)\n",
    "    local_squared_norm = dolfinx.fem.assemble_scalar(squared_norm_form)\n",
    "    return np.sqrt(mesh.comm.allreduce(local_squared_norm, op=mpi4py.MPI.SUM))  # type: ignore[no-any-return]\n",
    "\n",
    "\n",
    "def compute_errors(\n",
    "    u1: dolfinx.fem.Function, u2: dolfinx.fem.Function, uu1: dolfinx.fem.Function, uu2: dolfinx.fem.Function\n",
    ") -> None:\n",
    "    \"\"\"\n",
    "    Compute errors between standard and block cases.\n",
    "\n",
    "    For each component, the norm of the gradient of the difference between the two solutions is divided\n",
    "    by the norm of the gradient of the reference solution. Both relative errors are printed, and then\n",
    "    asserted to be zero up to an absolute tolerance of 1.e-10.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    u1, u2\n",
    "        Reference solutions for the first and second component.\n",
    "    uu1, uu2\n",
    "        Solutions to be compared to `u1` and `u2`, respectively.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    AssertionError\n",
    "        If either of the relative errors is not close to zero.\n",
    "    \"\"\"\n",
    "    u_1_norm = _h1_seminorm(u1)\n",
    "    u_2_norm = _h1_seminorm(u2)\n",
    "    err_1_norm = _h1_seminorm(u1 - uu1)\n",
    "    err_2_norm = _h1_seminorm(u2 - uu2)\n",
    "    print(\"  Relative error for first component is equal to\", err_1_norm / u_1_norm)\n",
    "    print(\"  Relative error for second component is equal to\", err_2_norm / u_2_norm)\n",
    "    assert np.isclose(err_1_norm / u_1_norm, 0., atol=1.e-10)\n",
    "    assert np.isclose(err_2_norm / u_2_norm, 0., atol=1.e-10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0884cd95",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Computing errors between standard and block\")\n",
    "compute_errors(u, u, u1, u2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython"
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
